{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DfPPQ6ztJhv4"
   },
   "source": [
    "# Mask R-CNN for Bin Picking\n",
    "\n",
    "This notebook is adapted from the [TorchVision 0.3 Object Detection finetuning tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html).  Here we load a pre-trained [Mask R-CNN](https://arxiv.org/abs/1703.06870) model that has been finetuned on a dataset generated from our \"clutter generator\" script, and run inference with it on a few test images.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DBIoe_tHTQgV"
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "import fnmatch\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import multiprocessing\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image\n",
    "from IPython.display import display\n",
    "\n",
    "import torch\n",
    "import torch.utils.data\n",
    "\n",
    "ycb = [\n",
    "    \"003_cracker_box.sdf\",\n",
    "    \"004_sugar_box.sdf\",\n",
    "    \"005_tomato_soup_can.sdf\",\n",
    "    \"006_mustard_bottle.sdf\",\n",
    "    \"009_gelatin_box.sdf\",\n",
    "    \"010_potted_meat_can.sdf\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XwyE5A8DGtct"
   },
   "source": [
    "# Download our bin-picking model\n",
    "\n",
    "And a small set of images for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_DgAgqauIET9"
   },
   "outputs": [],
   "source": [
    "dataset_path = \"clutter_maskrcnn_data\"\n",
    "if not os.path.exists(dataset_path):\n",
    "    # wget treats every positional argument as a URL and already saves into\n",
    "    # the working directory, so only the URL is passed.\n",
    "    !wget https://groups.csail.mit.edu/locomotion/clutter_maskrcnn_test.zip\n",
    "    !unzip -q clutter_maskrcnn_test.zip\n",
    "\n",
    "num_images = len(fnmatch.filter(os.listdir(dataset_path), \"*.png\"))\n",
    "\n",
    "\n",
    "def open_image(idx):\n",
    "    \"\"\"Return image number `idx` of the dataset as an RGB PIL image.\n",
    "\n",
    "    The file is looked up as `<dataset_path>/<idx zero-padded to 5 digits>.png`.\n",
    "    \"\"\"\n",
    "    filename = os.path.join(dataset_path, f\"{idx:05d}.png\")\n",
    "    return Image.open(filename).convert(\"RGB\")\n",
    "\n",
    "\n",
    "model_file = \"clutter_maskrcnn_model.pt\"\n",
    "if not os.path.exists(model_file):\n",
    "    # Same as above: no destination argument, wget saves next to the notebook.\n",
    "    !wget https://groups.csail.mit.edu/locomotion/clutter_maskrcnn_model.pt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xA8sBvuHNNH1"
   },
   "source": [
    "# Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vUJXn15pGzRj"
   },
   "outputs": [],
   "source": [
    "import torchvision\n",
    "from torchvision.models.detection.faster_rcnn import FastRCNNPredictor\n",
    "from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor\n",
    "from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights\n",
    "import torchvision.transforms.functional as Tf\n",
    "\n",
    "\n",
    "def get_instance_segmentation_model(num_classes):\n",
    "    \"\"\"Build a Mask R-CNN whose box and mask heads predict `num_classes` classes.\n",
    "\n",
    "    Starts from torchvision's COCO-pretrained `maskrcnn_resnet50_fpn` and\n",
    "    replaces its box predictor and mask predictor with new\n",
    "    `FastRCNNPredictor` / `MaskRCNNPredictor` heads sized for `num_classes`.\n",
    "    \"\"\"\n",
    "    # load an instance segmentation model pre-trained on COCO\n",
    "    model = torchvision.models.detection.maskrcnn_resnet50_fpn(\n",
    "        weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT\n",
    "    )\n",
    "\n",
    "    # get the number of input features for the classifier\n",
    "    in_features = model.roi_heads.box_predictor.cls_score.in_features\n",
    "    # replace the pre-trained head with a new one\n",
    "    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)\n",
    "\n",
    "    # now get the number of input features for the mask classifier\n",
    "    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels\n",
    "    hidden_layer = 256\n",
    "    # and replace the mask predictor with a new one\n",
    "    model.roi_heads.mask_predictor = MaskRCNNPredictor(\n",
    "        in_features_mask, hidden_layer, num_classes\n",
    "    )\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "# one extra class on top of the YCB objects (presumably the background class)\n",
    "num_classes = len(ycb) + 1\n",
    "model = get_instance_segmentation_model(num_classes)\n",
    "device = (\n",
    "    torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    ")\n",
    "# Load the checkpoint through `model_file` so that the path stays in sync with\n",
    "# the download cell.\n",
    "# NOTE(review): torch.load unpickles the downloaded file; on PyTorch versions\n",
    "# that support it, consider passing weights_only=True.\n",
    "model.load_state_dict(torch.load(model_file, map_location=device))\n",
    "model.eval()\n",
    "\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z6mYGFLxkO8F"
   },
   "source": [
    "# Evaluate the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YHwIdxH76uPj"
   },
   "outputs": [],
   "source": [
    "# pick one image from the test set (choose between 9950 and 9999)\n",
    "img = open_image(9952)\n",
    "img_tensor = Tf.to_tensor(img).to(device)\n",
    "\n",
    "# inference only, so run the forward pass without tracking gradients\n",
    "with torch.no_grad():\n",
    "    prediction = model([img_tensor])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DmN602iKsuey"
   },
   "source": [
    "Printing the prediction shows that we have a list of dictionaries. Each element\n",
    "of the list corresponds to a different image; since we have a single image,\n",
    "there is a single dictionary in the list. The dictionary contains the\n",
    "predictions for the image we passed. In this case, we can see that it contains\n",
    "`boxes`, `labels`, `masks` and `scores` as fields."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Lkmb3qUu6zw3"
   },
   "outputs": [],
   "source": [
    "prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RwT21rzotFbH"
   },
   "source": [
    "Let's inspect the image and the predicted segmentation masks.\n",
    "\n",
    "The tensor passed to the network was rescaled to 0-1 and had the channels moved first so that it is in `[C, H, W]` format, so for display we simply show the original PIL image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bpqN9t1u7B2J"
   },
   "outputs": [],
   "source": [
    "img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "M58J3O9OtT1G"
   },
   "source": [
    "And let's now visualize the predicted segmentation masks, one per subplot. The masks are predicted as `[N, 1, H, W]`, where `N` is the number of predictions, and are probability maps between 0-1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5v5S3bm07SO1"
   },
   "outputs": [],
   "source": [
    "# Show every predicted mask of the first (and only) image in its own subplot.\n",
    "masks = prediction[0][\"masks\"]\n",
    "N = masks.shape[0]\n",
    "if N == 0:\n",
    "    # plt.subplots rejects a grid with zero rows, so report it instead\n",
    "    print(\"No instances were detected in this image.\")\n",
    "else:\n",
    "    # squeeze=False always returns a 2D array of axes, so indexing also works\n",
    "    # when there is only a single prediction.\n",
    "    fig, ax = plt.subplots(N, 1, figsize=(15, 15), squeeze=False)\n",
    "    for n in range(N):\n",
    "        # scale the 0-1 probability map to 8-bit values for display\n",
    "        ax[n, 0].imshow(masks[n, 0].mul(255).byte().cpu().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z9QAeX9HkDTx"
   },
   "source": [
    "# Plot the object detections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z08keVFkvtPh"
   },
   "outputs": [],
   "source": [
    "import matplotlib.patches as patches\n",
    "import random\n",
    "\n",
    "\n",
    "def plot_prediction():\n",
    "    \"\"\"Draw the detections in the global `prediction` on top of the global `img`.\n",
    "\n",
    "    Each detection gets a randomly colored bounding box and a text label with\n",
    "    the matching entry of `ycb`, anchored at the box's top-left corner.\n",
    "    \"\"\"\n",
    "    img_np = np.array(img)\n",
    "    fig, ax = plt.subplots(1, figsize=(12, 9))\n",
    "    ax.imshow(img_np)\n",
    "\n",
    "    cmap = plt.get_cmap(\"tab20b\")\n",
    "    colors = [cmap(i) for i in np.linspace(0, 1, 20)]\n",
    "\n",
    "    num_instances = prediction[0][\"boxes\"].shape[0]\n",
    "    if num_instances <= len(colors):\n",
    "        bbox_colors = random.sample(colors, num_instances)\n",
    "    else:\n",
    "        # random.sample raises ValueError when asked for more items than the\n",
    "        # palette holds, so fall back to sampling with replacement.\n",
    "        bbox_colors = random.choices(colors, k=num_instances)\n",
    "    boxes = prediction[0][\"boxes\"].cpu().numpy()\n",
    "    labels = prediction[0][\"labels\"].cpu().numpy()\n",
    "\n",
    "    for i in range(num_instances):\n",
    "        color = bbox_colors[i]\n",
    "        # used below as (x_min, y_min, x_max, y_max)\n",
    "        bb = boxes[i, :]\n",
    "        bbox = patches.Rectangle(\n",
    "            (bb[0], bb[1]),\n",
    "            bb[2] - bb[0],\n",
    "            bb[3] - bb[1],\n",
    "            linewidth=2,\n",
    "            edgecolor=color,\n",
    "            facecolor=\"none\",\n",
    "        )\n",
    "        ax.add_patch(bbox)\n",
    "        # NOTE(review): labels index straight into `ycb` although the model has\n",
    "        # len(ycb) + 1 classes - confirm against the training label convention.\n",
    "        ax.text(\n",
    "            bb[0],\n",
    "            bb[1],  # y of the top-left corner, so the label sits on its box\n",
    "            s=ycb[labels[i]],\n",
    "            color=\"white\",\n",
    "            verticalalignment=\"top\",\n",
    "            bbox={\"color\": color, \"pad\": 0},\n",
    "        )\n",
    "    ax.axis(\"off\")\n",
    "\n",
    "\n",
    "plot_prediction()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HIfmykN-t7XG"
   },
   "source": [
    "# Visualize the region proposals \n",
    "\n",
    "Let's visualize some of the intermediate results of the networks.\n",
    "\n",
    "TODO: would be very cool to put a slider on this so that we could slide through ALL of the boxes.  But my matplotlib non-interactive backend makes it too tricky!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zBNqFb68td8N"
   },
   "outputs": [],
   "source": [
    "class Inspector:\n",
    "    \"\"\"A helper class from Kuni to be used for torch.nn.Module.register_forward_hook.\n",
    "\n",
    "    After a forward pass of the hooked module, `x` holds that module's output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.x = None\n",
    "\n",
    "    def hook(self, module, input, output):\n",
    "        \"\"\"Forward-hook callback: remember the module's output.\"\"\"\n",
    "        self.x = output\n",
    "\n",
    "\n",
    "inspector = Inspector()\n",
    "# Keep the handle so that the hook can be detached again; otherwise every\n",
    "# re-run of this cell would stack another hook on the RPN.\n",
    "hook_handle = model.rpn.register_forward_hook(inspector.hook)\n",
    "\n",
    "try:\n",
    "    with torch.no_grad():\n",
    "        prediction = model([Tf.to_tensor(img).to(device)])\n",
    "finally:\n",
    "    hook_handle.remove()\n",
    "\n",
    "rpn_values = inspector.x\n",
    "\n",
    "\n",
    "img_np = np.array(img)\n",
    "fig, ax = plt.subplots(1, figsize=(12, 9))\n",
    "ax.imshow(img_np)\n",
    "\n",
    "cmap = plt.get_cmap(\"tab20b\")\n",
    "colors = [cmap(i) for i in np.linspace(0, 1, 20)]\n",
    "\n",
    "# NOTE(review): the RPN output is indexed as [0][0] here, presumably\n",
    "# (proposals per image, losses) -> proposals of the first image; verify\n",
    "# against the torchvision version in use.\n",
    "boxes = rpn_values[0][0].cpu().numpy()\n",
    "# never try to draw more boxes than there are proposals\n",
    "num_to_draw = min(20, boxes.shape[0])\n",
    "bbox_colors = random.sample(colors, num_to_draw)\n",
    "print(\n",
    "    f\"Region proposals (drawing first {num_to_draw} out of {boxes.shape[0]})\"\n",
    ")\n",
    "\n",
    "for i in range(num_to_draw):\n",
    "    color = bbox_colors[i]\n",
    "    bb = boxes[i, :]\n",
    "    bbox = patches.Rectangle(\n",
    "        (bb[0], bb[1]),\n",
    "        bb[2] - bb[0],\n",
    "        bb[3] - bb[1],\n",
    "        linewidth=2,\n",
    "        edgecolor=color,\n",
    "        facecolor=\"none\",\n",
    "    )\n",
    "    ax.add_patch(bbox)\n",
    "ax.axis(\"off\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Try a few more images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pick one image from the test set (choose between 9950 and 9999)\n",
    "img = open_image(9985)\n",
    "img_tensor = Tf.to_tensor(img).to(device)\n",
    "\n",
    "# inference only, so run the forward pass without tracking gradients\n",
    "with torch.no_grad():\n",
    "    prediction = model([img_tensor])\n",
    "\n",
    "plot_prediction()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "clutter_maskrcnn_inference.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3.10.6 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
